Skip to content

[Gemma 4] Add multimodal support (apply_liger_kernel_to_gemma4 for Gemma4ForConditionalGeneration)#1203

Merged
Mecoli1219 merged 3 commits into
linkedin:mainfrom
dvdimitrov13:feat/gemma4-multimodal
May 20, 2026
Merged

[Gemma 4] Add multimodal support (apply_liger_kernel_to_gemma4 for Gemma4ForConditionalGeneration)#1203
Mecoli1219 merged 3 commits into
linkedin:mainfrom
dvdimitrov13:feat/gemma4-multimodal

Conversation

@dvdimitrov13
Copy link
Copy Markdown
Contributor

@dvdimitrov13 dvdimitrov13 commented Apr 27, 2026

Summary

Follow-up to #1196 (the text path) — adds apply_liger_kernel_to_gemma4 for Gemma4ForConditionalGeneration (multimodal class; includes E2B / E4B / E4B-it which are loaded by AutoModelForCausalLM as the multimodal class even when only text is being trained).

Closes the multimodal half of #1186.

Why

Gemma 4's text vocab is 262,144. Without FLCE the (B, T, V) bf16 logits tensor is ~17 GB at T=8192 (and ~34 GB once the loss path upcasts to fp32 for cross-entropy), which OOMs even 96 GB cards on Gemma4ForConditionalGeneration SFT — the OOM that originally motivated #1186. Routing loss through LigerForCausalLMLoss materializes only the loss scalar.

Shape

A single unified entry point that dispatches on class, per @Mecoli1219's preference in #1186:

  • apply_liger_kernel_to_gemma4(model=Gemma4ForConditionalGeneration_instance) — multimodal path. Class-level RMSNorm + GeGLU swaps via apply_liger_kernel_to_gemma4_text, FLCE forward via the new multimodal_forward, recurses into model.model.language_model for instance-level patches.
  • apply_liger_kernel_to_gemma4(model=Gemma4ForCausalLM_instance) — routes to apply_liger_kernel_to_gemma4_text for backwards compatibility, so the same entry point works for either shape.
  • Registry: adds "gemma4": apply_liger_kernel_to_gemma4 alongside the existing "gemma4_text".

Drive-by fixes

Two isinstance(model, tuple_with_None_filter) sites (one in our new dispatcher, one in the existing apply_liger_kernel_to_gemma4_text) raised TypeError: isinstance() arg 2 must be a type, a tuple of types, or a union when called under with patch("transformers.models.gemma4.modeling_gemma4"): getattr(MagicMock_module, "Gemma4TextForCausalLM", None) returns a MagicMock (not None) because MagicMock auto-creates attributes, so the cls is not None filter let it slip into the isinstance tuple. The text-path version was dormant — its existing test passes a Gemma4ForCausalLM which short-circuits the isinstance match before reaching the bad entry — but the multimodal recursive call into the text path passes a Gemma4TextModel, so no early match. Both sites now use isinstance(cls, type) as the filter.

Out of scope (deferred)

  • Vision / audio tower kernel swaps. Gemma 4 loads both via AutoModel.from_config(config.{vision,audio}_config), so the module classes are polymorphic. Out of scope here; FLCE on the LM head is what unblocks training OOM.
  • Gemma4MultimodalEmbedder / projector norms. Analogous to gemma3's mm_soft_emb_norm patching. Skipped for the same minimal-surface reason.
  • PLE (Per-Layer Embeddings). State passes through the inner forward unchanged; verified end-to-end on google/gemma-4-E4B-it.
  • MoE experts (Gemma4TextExperts). Guarded out via the same enable_moe_block check used by the text path in [Gemma 4] Add apply_liger_kernel_to_gemma4_text (dense text, 31B-targeted) #1196.
  • Multimodal mini convergence test (mini_gemma4). Would need new test/resources/fake_configs/Google/Gemma4/... scaffolding for an image / audio processor — PR [Gemma 4] Add apply_liger_kernel_to_gemma4_text (dense text, 31B-targeted) #1196 followed the same pattern (only added mini_gemma4_text to the non-multimodal convergence files). Happy to add it as a follow-up if you'd prefer it bundled here — let me know.

Testing Done

Hardware

  • Hardware Type: RTX 5090 (Blackwell, sm_120, 32 GB), Vast.ai instance
  • CUDA 13.0, NVIDIA driver 590.48.01
  • Python 3.10.12, torch 2.11.0+cu130, transformers 5.7.0.dev0 (built from huggingface/transformers@main — gemma4 requires ≥ 5.5.0)
  • bf16 throughout for end-to-end numerics

End-to-end numerical equivalence on real google/gemma-4-E4B-it

Verified before authoring this PR with our internal verify_patch_equivalence.py (same shape as #1196's verification):

Axis Threshold Measured
Fused-CE vs reference loss diff (seq=512) < 5e-3 0.0016
End-to-end top-1 token agreement (seq=256, patched vs stock SDPA) > 99 % >99 %
End-to-end loss diff (patched vs stock SDPA) < 5e-3 0.0016
Logit-distribution KL (informational) ~3.5e-2
SDPA EFFICIENT_ATTENTION vs MATH cos-sim (seq=4096) > 0.9999 >0.9999

Liger-Kernel test gates

  • make checkstyleAll checks passed!, 267 files already formatted
  • make test — see log
  • make test-convergence — see logs (per file)

make test

3131 passed, 903 skipped, 12 xfailed, 3 failed in 35:42.

Our two new unit tests pass cleanly:

PASSED  test/transformers/test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_gemma4_text                       (0.70s)
PASSED  test/transformers/test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_gemma4_conditional_generation     (0.22s)

All six LigerGEGLUMLPForGemma4 edge-case tests added by #1196 also pass.

The 3 unrelated failures are pre-existing on the parent branch (untouched by this PR) and reproduce on main:

FAILED  test/transformers/test_grpo_loss.py::test_grpo_loss_with_bias_correction_kl[...]
FAILED  test/transformers/test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_paligemma
FAILED  test/transformers/test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_gemma3_conditional_generation

make test-convergence (per file)

File Result
fp32/test_mini_models.py 32 passed, 3 skipped
fp32/test_mini_models_multimodal.py 8 passed, 4 skipped, 1 xfailed, 1 failed (mini_qwen2_vl)
fp32/test_mini_models_with_logits.py 29 passed, 3 skipped, 1 xpassed
bf16/test_mini_models.py 33 passed, 2 failed (mini_llama4, mini_gemma4_text)
bf16/test_mini_models_multimodal.py 11 passed, 1 skipped, 2 failed (mini_qwen2_vl, mini_llama4)
bf16/test_mini_models_with_logits.py 30 passed, 1 skipped, 2 failed (mini_llama4, mini_qwen3_moe)

mini_gemma4_text passes in bf16/test_mini_models_with_logits.py and fp32/test_mini_models.py (PR #1196's text path). It fails only in bf16/test_mini_models.py with the same Blackwell bf16 logprob drift @eqy and reviewers flagged on #1196 (review comment r4321013177); this is in PR #1196's territory, not introduced by this multimodal patch. The other failures (mini_llama4, mini_qwen2_vl, mini_qwen3_moe) are also in models we don't touch and are pre-existing test instability on consumer Blackwell.

cc @Mecoli1219 @lardinator @ruilin-gif


🤖 Drafted with Claude Code (Claude Opus 4.7), reviewed and posted by me.

@Mecoli1219
Copy link
Copy Markdown
Collaborator

Hi @dvdimitrov13 . Do you have any update on this pr?

Follow-up to linkedin#1196 — adds the multimodal entry point
`apply_liger_kernel_to_gemma4` for `Gemma4ForConditionalGeneration`.

The (B, T, V) bf16 logits tensor on Gemma 4 multimodal training is
~17 GB at T=8192 / vocab=262,144 (and ~34 GB once the loss path upcasts
to fp32), OOMing 96 GB cards on Gemma4ForConditionalGeneration SFT.
Routing loss through `LigerForCausalLMLoss` materializes only the loss
scalar.

Shape — unified entry point dispatching on class (per @Mecoli1219's
preference in linkedin#1186):

- `Gemma4ForConditionalGeneration` → installs `multimodal_forward`,
  class-level RMSNorm + GeGLU swaps, recurses into
  `model.model.language_model` for instance-level patches.
- `Gemma4ForCausalLM` / `Gemma4TextForCausalLM` / `Gemma4TextModel` →
  routes to `apply_liger_kernel_to_gemma4_text`.
- Registry: adds `"gemma4"` alongside existing `"gemma4_text"`.

Drive-by: replaces `cls is not None` with `isinstance(cls, type)` in
`apply_liger_kernel_to_gemma4_text`'s `causal_lm_types` filter. The
dormant bug fires when the multimodal dispatcher recurses into the text
path with a `Gemma4TextModel` — MagicMock (from `unittest.mock.patch`)
auto-creates a `Gemma4TextForCausalLM` attribute that slipped past the
`is not None` filter and landed in an `isinstance(tuple)`, raising
TypeError.

Closes the multimodal half of linkedin#1186.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@dvdimitrov13 dvdimitrov13 force-pushed the feat/gemma4-multimodal branch from 73f941f to d5b0578 Compare May 15, 2026 10:50
@dvdimitrov13 dvdimitrov13 marked this pull request as ready for review May 15, 2026 10:51
@dvdimitrov13
Copy link
Copy Markdown
Contributor Author

Apologies for the silence @Mecoli1219 - relatively new to open source contribution, and I misread the draft convention as "waiting for maintainer pre-review before final polish", which I now realise isn't how it works on external repos. Rebased onto current main as a single squashed commit (the previous 31 commits had drifted off the merge base from before #1196 landed), conflicts resolved, marked ready for review.

Copy link
Copy Markdown
Collaborator

@Mecoli1219 Mecoli1219 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @dvdimitrov13, and welcome to the open-source community! Nice first contribution!

Left some inline comments. One bigger-picture thought: we should plan to support Gemma4's vision and audio towers (Gemma4VisionModel / Gemma4AudioModel) in a follow-up. Happy to scope that as a separate PR. Just wanted to flag it so we don't lose track.

# `if config.<m>_config is not None`, so a None-towers model still
# constructs as Gemma4ForConditionalGeneration and exercises the
# multimodal forward we're patching. The towers themselves are
# polymorphic (AutoModel.from_config) and not in this PR's scope.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gemma4 ships dedicated Gemma4VisionModel & Gemma4AudioModel classes (concrete, not polymorphic like gemma3's SigLIP tower). RMSNorm and the GeGLU/SwiGLU-style MLPs in those towers should be a near-direct port from the text patches — worth doing in a follow-up.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will open a follow-up PR after this lands. Scope as I see it: RMSNorm + GeGLU/SwiGLU on Gemma4VisionModel and Gemma4AudioModel, the multimodal projector norms (analogous to gemma3's mm_soft_emb_norm), and the audio convergence test scaffolding I deferred from this PR.

Happy to open a tracking issue first if you'd prefer that over a draft PR.

Comment thread test/utils.py
print("Liger kernel patches have been reverted.")


def revert_liger_kernel_to_gemma4(model_config: MiniModelConfig):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have a convergence test for Gemma4 multimodal model?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in f23039f. Wanted to be transparent about what came up during validation.

What's in the commit

  • New test/resources/fake_configs/Google/Gemma4/gemma-4-e4b-it/tokenizer_config.json. The chat template emits <image_soft_token> (not <start_of_image> like gemma3): Gemma4Processor.__call__ expands image_token placeholders to <boi><image_token>*n<eoi>, a different pattern from gemma3's boi-token replacement.
  • mini_gemma4 entries in bf16 and fp32 test_mini_models_multimodal.py. Vision config has patch_size=16 (matching Gemma4ImageProcessor's hardcoded default - the processor doesn't accept a patch_size kwarg). audio_config=None, so audio coverage stays in the vision/audio tower follow-up.
  • apply_liger_kernel_to_gemma4 now accepts layer_norm: bool = False as a no-op kwarg. The convergence framework passes layer_norm=True by default for any model not in its exclusion list, and Gemma4 vision uses RMSNorm. Accept-and-no-op is consistent with the deferred vision/audio tower scope.

Validation, asking for guidance

I spun up an RTX 3090 (Ampere) Vast.ai instance to validate locally. After fixing the items above, the test runs end-to-end, but the numerical comparison fails. So does the baseline mini_gemma3 in both bf16 and fp32 on the same env. My read is this is the same env-sensitivity the original PR ran into with mini_gemma4_text on Blackwell bf16.

Env I used: PyTorch 2.5.1 (cuda12.4 image), transformers 5.7.0 from PyPI. I couldn't pull transformers main because it imports CPUOffloadPolicy from torch.distributed.fsdp which requires torch >= 2.6.

A couple of paths forward:

  1. Land as-is and let your CI environment validate (assuming it's the configuration where mini_gemma3 passes today).
  2. Happy to retry on a different GPU + torch + transformers combination if you have a known-green spec.

I'd appreciate your call.

text_classes = tuple(
cls for cls in (Gemma4ForCausalLM, Gemma4TextForCausalLM, Gemma4TextModel) if isinstance(cls, type)
)
if isinstance(model, text_classes):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this path ever happen with _apply_liger_kernel? If not, we can throw error like how gemma3 did.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confirmed - the routing branch is dead via _apply_liger_kernel. Text variants dispatch through the "gemma4_text" registry entry directly, never reaching "gemma4". Matched gemma3's pattern in d0f7516: TypeError on a non-ConditionalGeneration instance.

The isinstance(cls, type) filter in apply_liger_kernel_to_gemma4_text stays - the recursive call from the multimodal path still hits it with a Gemma4TextModel under unittest.mock.patch.

dvdimitrov13 and others added 2 commits May 19, 2026 23:05
Per @Mecoli1219's review on linkedin#1203: the text-routing branch in
apply_liger_kernel_to_gemma4 is dead code via _apply_liger_kernel.
The framework dispatches text variants (Gemma4ForCausalLM /
Gemma4TextForCausalLM / Gemma4TextModel) through the "gemma4_text"
registry entry directly, never reaching the multimodal "gemma4"
entry. Match gemma3's pattern: raise TypeError on a non-
ConditionalGeneration instance.

The drive-by isinstance(cls, type) filter in
apply_liger_kernel_to_gemma4_text stays - the recursive call from
this function still hits the text path with a Gemma4TextModel
instance under unittest.mock.patch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Per @Mecoli1219's review on linkedin#1203 (comment on test/utils.py line 508 re:
revert_liger_kernel_to_gemma4): adds the corresponding mini convergence
test for the multimodal entry point, mirroring mini_gemma3's pattern.

Scope: image+text path through Gemma4ForConditionalGeneration.

- New test/resources/fake_configs/Google/Gemma4/gemma-4-e4b-it/
  tokenizer_config.json. Chat template emits <image_soft_token>;
  Gemma4Processor.__call__ expands image_token placeholders to
  <boi><image_token>*n<eoi>, which differs from gemma3's boi-token
  pattern.
- bf16 and fp32 test_mini_models_multimodal.py: GEMMA4_AVAILABLE block,
  MINI_MODEL_SETUPS["mini_gemma4"] with vision config (patch_size=16
  matches Gemma4ImageProcessor's hardcoded default; audio_config=None),
  create_processor branch, pytest.param with gemma3-matching tolerances.
- apply_liger_kernel_to_gemma4 accepts layer_norm: bool = False as a
  no-op kwarg. The convergence framework defaults to passing
  layer_norm=True; Gemma4 vision uses RMSNorm so accept-and-no-op is
  consistent with the deferred vision/audio tower scope.

Audio coverage (Gemma4AudioModel + procedural audio generation in the
convergence harness) is intentionally deferred to the vision/audio
tower follow-up PR; the multimodal forward we patch exercises the LM
head FLCE which is the OOM unblock this PR ships.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Copy link
Copy Markdown
Collaborator

@Mecoli1219 Mecoli1219 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @dvdimitrov13, I tested locally on H100. There are some numerical drifts that cause failures in the convergence test, but they're within the range we see for other multimodal models in bf16. I think it's fine to merge as-is.

Thanks again for the contribution! If you'd like to take on the vision/audio tower follow-up, feel free to open a PR and ping me when it's ready. Looking forward to your future contributions!

@Mecoli1219 Mecoli1219 added this pull request to the merge queue May 20, 2026
Merged via the queue into linkedin:main with commit decb1b7 May 20, 2026
5 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants